{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Busecke et al. GRL EUC Preprocessing\n",
    "Notebook to convert/subset the data needed for the paper from the GFDL archive.\n",
    "> Note: The input files are staged in `/work`. To reconstruct the output files, either replicate my workflow to [stage files](https://github.com/jbusecke/guides/blob/master/gfdl_file_management.md) or replace `/work` with `/archive`. Reading from `/archive` will take substantially longer (files are likely stored on tape and need to be retrieved).\n",
    "\n",
    "On which domain is the EUC_shape detection performed? And is that described in the text?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TODO before resubmission:\n",
    "- Publish xarrayutils with momtools or standalone mom_tools (then change import below)\n",
    "- Some units are messed up (Johnson is still in cm/s)\n",
    "- Check the budget units in mv eof and budget plots (is it per kg?)\n",
    "- I added eofs package to env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For the dev\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xarray as xr\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import os\n",
    "from os.path import join\n",
    "from dask.diagnostics import ProgressBar\n",
    "from xgcm.autogenerate import generate_grid_ds\n",
    "from xgcm import Grid\n",
    "from xarrayutils.xgcm_utils import xgcm_weighted_mean, dll_dist\n",
    "\n",
    "# There is a version of this from xarray...\n",
    "from xarrayutils import aggregate\n",
    "\n",
    "# #!!! at the point of publication This needs to be packaged into a proper module, with version that can be included in the requirements\n",
    "# also in `basic.py`\n",
    "import sys \n",
    "sys.path.append('/home/Julius.Busecke/projects/mom_tools')\n",
    "from mom_budgets import add_vertical_spacing, add_split_tendencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from global_oxygen.GRL_paper import mom_recombine_budgets, eq_mean, lat_slice, depth_slice, extract_shape, \\\n",
    "    extract_EUC_stats, xr_extract_contour, calculate_momentum_budget\n",
    "\n",
    "from global_oxygen.basic import load_obs_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "outdir = '../data/interim/GRL_Paper_clean/'\n",
    "# `makedirs` also creates missing parent folders and, with `exist_ok=True`,\n",
    "# does not fail if the folder already exists.\n",
    "os.makedirs(outdir, exist_ok=True)\n",
    "\n",
    "\n",
    "def save_GRL_data(ds, name):\n",
    "    \"\"\"Write `ds` as netcdf to `name` inside `outdir`, replacing an existing file.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    ds : object providing `to_netcdf` (the xarray objects created in this notebook)\n",
    "    name : str\n",
    "        File name, joined onto `outdir`.\n",
    "    \"\"\"\n",
    "    filename = join(outdir, name)\n",
    "    print('Saving to %s' %filename)\n",
    "    # Delete a previous version of the file before writing the new one\n",
    "    if os.path.exists(filename):\n",
    "        os.remove(filename)\n",
    "    ds.to_netcdf(filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = ['CM21deg', 'CM26']\n",
    "cmip = ['CMIP5_piControl', 'CMIP5_rcp85']\n",
    "obs = ['Bianchi', 'Johnson']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The step to convert all runs to zarr is missing\n",
    "# from mom_read import open_mom5_CM_ESM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load all preprocessed data\n",
    "There are a lot of steps that need to be documented for the publication, but for now work from here.\n",
    "\n",
    "1. Import all the processed datasets needed\n",
    "2. Average them all together for the equatorial slices in one go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prelim_roi(obj):\n",
    "    \"\"\"Rough region cut to avoid large files.\n",
    "\n",
    "    Every dimension listed in `roi` that exists on `obj` is restricted to a fixed\n",
    "    coordinate range. If `obj` has a `run` dimension, only `run='control'` is kept.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    obj : xarray object (Datasets in this notebook)\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    A new, subset object; the input is not modified.\n",
    "    \"\"\"\n",
    "    x_roi = slice(-260, -60)\n",
    "    y_roi = slice(-20, 20)\n",
    "    z_roi = slice(0, 1000)\n",
    "    roi = {\n",
    "        'xt_ocean': x_roi,\n",
    "        'xu_ocean': x_roi,\n",
    "        'yt_ocean': y_roi,\n",
    "        'yu_ocean': y_roi,\n",
    "        'st_ocean': z_roi,\n",
    "        'sw_ocean': z_roi,\n",
    "    }\n",
    "    # The `run` selection does not depend on the roi entries, so it is applied once\n",
    "    # instead of being re-checked for every entry of `roi`.\n",
    "    if 'run' in obj.dims:\n",
    "        obj = obj.sel(run='control')\n",
    "    # Only select along dimensions that are actually present.\n",
    "    # `sel` returns a new object, which makes a defensive copy unnecessary.\n",
    "    present = {dim: rng for dim, rng in roi.items() if dim in obj.dims}\n",
    "    return obj.sel(**present)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vertical spacing for vertical vel cell is approximated!!! Use with caution\n",
      "Spacing for `dxte` is approximated!!! Use with caution\n",
      "Spacing for `dytn` is approximated!!! Use with caution\n",
      "Vertical spacing for vertical vel cell is approximated!!! Use with caution\n",
      "Spacing for `dxte` is approximated!!! Use with caution\n",
      "Spacing for `dytn` is approximated!!! Use with caution\n",
      "dict_keys(['CM21deg', 'CM26'])\n"
     ]
    }
   ],
   "source": [
    "# NOTE: this cell is very slow due to a strange behaviour in `add_vertical_spacing`.\n",
    "# An issue has been raised over at xarray (https://github.com/pydata/xarray/issues/2867)\n",
    "models = ['CM21deg', 'CM26']\n",
    "data = {}\n",
    "kwargs_dict = {}\n",
    "\n",
    "for name in models:\n",
    "    ds_model = xr.open_zarr('../data/processed/%s.zarr' %name)\n",
    "    # Add the vertical spacing first, then the split tendencies\n",
    "    ds_model = add_split_tendencies(add_vertical_spacing(ds_model))\n",
    "    # `mom_recombine_budgets` is not applied here (moved to plotting to save space?)\n",
    "    data[name] = prelim_roi(ds_model)\n",
    "\n",
    "# Everything loaded up to this point is flagged for model processing\n",
    "for k in data:\n",
    "    kwargs_dict[k] = {'model_processing':True}\n",
    "print(data.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observations\n",
    "# NOTE: `obs` is rebound here from the list of names defined further up to a dict of datasets.\n",
    "\n",
    "# Gridded obs, interpolated onto the first time step of CM21deg\n",
    "obs_dict = load_obs_dict(fid_dict=['Bianchi']) #!!! This violates the reproducibility by having an absolute path\n",
    "ref = data['CM21deg'].isel(time=0).squeeze()\n",
    "\n",
    "obs = dict()\n",
    "for k, v in obs_dict.items():\n",
    "    obs[k] = v.load().interp_like(ref)\n",
    "    # attach the model grid metrics as coordinates\n",
    "    for var in ['dxt', 'dyt', 'dzt']:\n",
    "        obs[k].coords[var] = ref[var]\n",
    "\n",
    "# Velocity obs (used as stored on disk, no interpolation)\n",
    "obs.update(\n",
    "    Johnson=xr.open_dataset('../data/processed/johnson_obs.nc'),\n",
    "    tao=xr.open_dataset('../data/processed/tao_obs.nc'),\n",
    ")\n",
    "\n",
    "# Register all observations next to the models, flagged as non-model data\n",
    "for k, v in obs.items():\n",
    "    kwargs_dict[k] = {'model_processing':False}\n",
    "    data[k] = v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CMIP equatorial slices (!!! need to reinterpolate also away from the equator for now just add)\n",
    "cmip_path = '../data/processed/'\n",
    "cmip_names = ['CMIP5_tropical_combined_piControl', 'CMIP5_tropical_combined_rcp85']\n",
    "# CMIP variable/dimension names -> names used for the model output in this notebook\n",
    "cmip_rename = {'uo':'u', 'lev':'st_ocean', 'lon':'xt_ocean', 'lat':'yt_ocean'}\n",
    "\n",
    "for name in cmip_names:\n",
    "    short_name = name.replace('_tropical_combined', '')\n",
    "    ds_cmip = xr.open_dataset('%s%s.nc' %(cmip_path, name)).squeeze()\n",
    "    # cmip o2 values are given as [umol/m^3] (check) and thus need to be divided by ref density.\n",
    "    # I assume this is 1035 kg/m^3 (!!!check)\n",
    "    ds_cmip['o2'] = ds_cmip['o2'] / 1035\n",
    "\n",
    "    # Create grid metrics from the lon/lat axes\n",
    "    ds_full = generate_grid_ds(ds_cmip, {'X':'lon', 'Y':'lat'})\n",
    "    grid = Grid(ds_full)\n",
    "    diff_kwargs = dict(boundary='fill', fill_value=np.nan)\n",
    "    dlonc = grid.diff(ds_full.lon_left, 'X', **diff_kwargs)\n",
    "    dlatc = grid.diff(ds_full.lat_left, 'Y', **diff_kwargs)\n",
    "    dxt, dyt = dll_dist(dlonc, dlatc, ds_full.lon, ds_full.lat)\n",
    "    ds_cmip.coords['dxt'] = dxt\n",
    "    ds_cmip.coords['dyt'] = dyt\n",
    "\n",
    "    # there is some jump, for now just cut off the last lon index\n",
    "    # (this could be done cleanly with the boundary discontinuity)\n",
    "    ds_cmip = ds_cmip.isel(lon=slice(0,-1))\n",
    "\n",
    "    kwargs_dict[short_name] = {'model_processing':False}\n",
    "    data[short_name] = ds_cmip.rename(cmip_rename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dont keep all variables or the files for archival get bloated\n",
    "o2_budget_vars = [a for a in data['CM21deg'].data_vars if (('o2_' in a) or ('_o2' in a)) and ( a not in [\n",
    "    'o2_btf',\n",
    "    'o2_eta_smooth',\n",
    "    'o2_rivermix',\n",
    "    'o2_stf',\n",
    "    'o2_xland',\n",
    "    'o2_xlandinsert',\n",
    "    'o2_xflux_adv',\n",
    "    'o2_yflux_adv',\n",
    "    'o2_zflux_adv',\n",
    "    'o2_xflux_adv_recon',\n",
    "    'o2_yflux_adv_recon',\n",
    "    'o2_zflux_adv_recon',\n",
    "    'o2_advection_y_recon_tracer',\n",
    "    'o2_advection_z_recon_tracer',\n",
    "    'o2_advection_y_recon_vel',\n",
    "    'o2_advection_z_recon_vel',\n",
    "    'o2_xflux_submeso',\n",
    "    'o2_yflux_submeso',\n",
    "    'o2_zflux_submeso',\n",
    "])] + ['jo2']\n",
    "keep_vars = ['o2', 'u', 'v', 'pot_rho_0', 'dzt', 'NINO34'] + o2_budget_vars\n",
    "\n",
    "def xr_keep(obj, varlist=keep_vars):\n",
    "    obj = obj.copy()\n",
    "    keep_vars = [a for a in varlist if a in obj.data_vars]\n",
    "    return obj.get(keep_vars)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing the equatorial average and depth/lat slices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving to CM21deg_eq_mean.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM21deg_eq_mean.nc\n",
      "[########################################] | 100% Completed | 56.9s\n",
      "saving to CM26_eq_mean.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM26_eq_mean.nc\n",
      "[########################################] | 100% Completed | 21min 10.0s\n",
      "saving to Bianchi_eq_mean.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/Bianchi_eq_mean.nc\n",
      "[########################################] | 100% Completed |  0.2s\n",
      "saving to Johnson_eq_mean.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/Johnson_eq_mean.nc\n",
      "saving to CMIP5_piControl_eq_mean.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CMIP5_piControl_eq_mean.nc\n",
      "saving to CMIP5_rcp85_eq_mean.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CMIP5_rcp85_eq_mean.nc\n"
     ]
    }
   ],
   "source": [
    "# Equatorial slices (average from 1S-1N) - with full time variability retained...\n",
    "# `eq_mean` is imported together with the other `global_oxygen.GRL_paper` helpers near the top.\n",
    "for k, ds in data.items():\n",
    "    # no equatorial mean is computed for tao\n",
    "    if k in ['tao']:\n",
    "        continue\n",
    "    filename = '%s_eq_mean.nc' %k\n",
    "    # `xr_keep` already returns a new object, so no extra copy of `ds` is needed\n",
    "    ds_eq = prelim_roi(eq_mean(xr_keep(ds), **kwargs_dict[k]))\n",
    "    print('saving to %s' %filename)\n",
    "    with ProgressBar():\n",
    "        save_GRL_data(ds_eq, filename)\n",
    "# Only remove names created in this cell. Deleting a name that merely leaked from\n",
    "# another cell makes this cell fail with a NameError depending on execution order.\n",
    "del ds_eq, k, ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lat slices along Johnson observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CM21deg\n",
      "10\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM21deg_lat_slice.nc\n",
      "[########################################] | 100% Completed |  8.6s\n",
      "CM26\n",
      "10\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM26_lat_slice.nc\n",
      "[########################################] | 100% Completed |  2min 30.6s\n",
      "Bianchi\n",
      "10\n",
      "Saving to ../data/interim/GRL_Paper_clean/Bianchi_lat_slice.nc\n",
      "[########################################] | 100% Completed |  0.4s\n",
      "Johnson\n",
      "10\n",
      "Saving to ../data/interim/GRL_Paper_clean/Johnson_lat_slice.nc\n",
      "[########################################] | 100% Completed |  0.1s\n",
      "CMIP5_piControl\n",
      "10\n",
      "Saving to ../data/interim/GRL_Paper_clean/CMIP5_piControl_lat_slice.nc\n",
      "[########################################] | 100% Completed |  7.9s\n",
      "CMIP5_rcp85\n",
      "10\n",
      "Saving to ../data/interim/GRL_Paper_clean/CMIP5_rcp85_lat_slice.nc\n",
      "[########################################] | 100% Completed | 24.5s\n"
     ]
    }
   ],
   "source": [
    "for kk, ds in data.items():\n",
    "    # MIMOC has neither o2 nor u, and tao lat slices are not very useful?\n",
    "    if kk in ['tao', 'MIMOC']:\n",
    "        continue\n",
    "\n",
    "    print(kk)\n",
    "    filename = '%s_lat_slice.nc' %kk\n",
    "    # the Johnson velocities are handed to `lat_slice` as reference\n",
    "    ref = data['Johnson'].u\n",
    "    ds = xr_keep(ds.copy(), ['u', 'o2'])\n",
    "\n",
    "    # rechunk each x dimension that is present with a chunk size of -1\n",
    "    x_dims = [dd for dd in ['xt_ocean', 'xu_ocean'] if dd in list(ds.dims)]\n",
    "    for dd in x_dims:\n",
    "        ds = ds.chunk({dd:-1})\n",
    "\n",
    "    ds_lat_slice = lat_slice(prelim_roi(ds), ref, **kwargs_dict[kk])\n",
    "\n",
    "    # collapse whichever temporal dimension the dataset carries\n",
    "    for avg_dim in ['time', 'month']:\n",
    "        if avg_dim in ds_lat_slice.dims:\n",
    "            ds_lat_slice = ds_lat_slice.mean(avg_dim)\n",
    "\n",
    "    print(len(ds_lat_slice.xt_ocean))\n",
    "    with ProgressBar():\n",
    "        save_GRL_data(ds_lat_slice, filename)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Slices at constant depth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "saving depth slice to CM21deg_depth_slice.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM21deg_depth_slice.nc\n",
      "[########################################] | 100% Completed |  9.4s\n",
      "saving depth slice to CM26_depth_slice.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM26_depth_slice.nc\n",
      "[########################################] | 100% Completed |  1min 52.7s\n",
      "saving depth slice to Bianchi_depth_slice.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/Bianchi_depth_slice.nc\n",
      "[########################################] | 100% Completed |  0.2s\n"
     ]
    }
   ],
   "source": [
    "# Depth slices\n",
    "for k, ds in data.items():\n",
    "    filename = '%s_depth_slice.nc' %k\n",
    "    if k in ['CM21deg', 'CM26', 'Bianchi']:\n",
    "        ds_slice = depth_slice(xr_keep(ds, ['o2', 'u']), st_ocean=250, **kwargs_dict[k]) #Here I only need the o2\n",
    "        # Bin averaging the high res model\n",
    "        if 'CM26' in k:\n",
    "            temp = xr.Dataset()\n",
    "            for var in ds_slice.data_vars:\n",
    "                with ProgressBar():\n",
    "                    temp[var] = aggregate(ds_slice[var].chunk({'xt_ocean':100, 'yt_ocean':100}),\n",
    "                                           [('xt_ocean', 20),('yt_ocean',20)])\n",
    "                \n",
    "            ds_slice = temp                \n",
    "        print('saving depth slice to %s' %filename)\n",
    "        with ProgressBar():\n",
    "            save_GRL_data(ds_slice, filename) \n",
    "            del ds_slice, k, ds, filename"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract the EUC shape parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving to CM21deg_euc_shape.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM21deg_euc_shape.nc\n",
      "Saving to CM26_euc_shape.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM26_euc_shape.nc\n",
      "Saving to Johnson_euc_shape.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/Johnson_euc_shape.nc\n",
      "Something went wrong during interpolation\n",
      "x [230.0, 235.0]\n",
      "y [0.22620001, 0.2]\n",
      "c 0.2\n",
      "Something went wrong during interpolation\n",
      "x [260.0, 265.0]\n",
      "y [0.21700001, 0.2]\n",
      "c 0.2\n",
      "Something went wrong during interpolation\n",
      "x [140.0, 145.0]\n",
      "y [0.2, 0.24416131]\n",
      "c 0.2\n",
      "Saving to tao_euc_shape.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/tao_euc_shape.nc\n",
      "Saving to CMIP5_piControl_euc_shape.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CMIP5_piControl_euc_shape.nc\n",
      "Saving to CMIP5_rcp85_euc_shape.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/CMIP5_rcp85_euc_shape.nc\n"
     ]
    }
   ],
   "source": [
    "# extract EUC stats\n",
    "EUC_boundary = 0.2\n",
    "\n",
    "\n",
    "def process_and_save_euc_shape(data):\n",
    "    \"\"\"Compute and save the EUC shape for every dataset in `data` that has a `u` variable.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : dict\n",
    "        Maps a dataset name to a dataset. The name is used for the output file\n",
    "        `<name>_euc_shape.nc`, written via `save_GRL_data`.\n",
    "    \"\"\"\n",
    "    for k, v in data.copy().items():\n",
    "        if 'u' not in v.data_vars:\n",
    "            continue\n",
    "\n",
    "        if k in ['tao']:\n",
    "            # tao has no lat dimension, so it has to be processed differently\n",
    "            temp = extract_shape(v.u, EUC_boundary)\n",
    "            euc_shape = xr.Dataset({'core_value_z':temp.value,\n",
    "                                    'core_pos_z':temp.position,\n",
    "                                    'thickness':temp.extent,\n",
    "                                    'upper':temp.inner,\n",
    "                                    'lower':temp.outer,\n",
    "                                    })\n",
    "        else:\n",
    "            # velocities stored on the tracer grid are renamed to the velocity grid names\n",
    "            if 'xt_ocean' in v.u.dims:\n",
    "                v = v.rename({'xt_ocean':'xu_ocean', 'yt_ocean':'yu_ocean', 'dyt':'dyu'})\n",
    "            euc_shape = extract_EUC_stats(v, EUC_boundary)\n",
    "\n",
    "        # record the threshold with the result\n",
    "        euc_shape.attrs['EUC_boundary'] = EUC_boundary\n",
    "\n",
    "        filename = '%s_euc_shape.nc' %k\n",
    "        print('Saving to %s' %filename)\n",
    "        save_GRL_data(euc_shape, filename)\n",
    "\n",
    "\n",
    "process_and_save_euc_shape(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract the omz shape parameter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CM21deg\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM21deg_omz_shape.nc\n",
      "CM26\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM26_omz_shape.nc\n",
      "CMIP5_piControl\n",
      "Saving to ../data/interim/GRL_Paper_clean/CMIP5_piControl_omz_shape.nc\n",
      "CMIP5_rcp85\n",
      "Saving to ../data/interim/GRL_Paper_clean/CMIP5_rcp85_omz_shape.nc\n",
      "Bianchi\n",
      "Saving to ../data/interim/GRL_Paper_clean/Bianchi_omz_shape.nc\n",
      "Johnson\n"
     ]
    }
   ],
   "source": [
    "# `obs` is a dict of datasets (including 'tao') by the time this cell runs in a\n",
    "# top-to-bottom execution, so `models + cmip + obs` raises a TypeError.\n",
    "# Spell out the observational products that have an `_eq_mean.nc` file instead.\n",
    "names = models + cmip + ['Bianchi', 'Johnson']\n",
    "\n",
    "ds_eq = dict()\n",
    "for k in names:\n",
    "    ds_eq[k] = xr.open_dataset(join(outdir, '%s_eq_mean.nc' %k))\n",
    "\n",
    "# Average the oxygen observations over month of the year.\n",
    "ds_eq['Bianchi'] = ds_eq['Bianchi'].mean('month')\n",
    "\n",
    "tilt_threshold = 80 # in mumol\n",
    "tilt_roi = dict(xt_ocean=slice(-230, -90), st_ocean=slice(0,500))\n",
    "\n",
    "for k, v in ds_eq.items():\n",
    "    print(k)\n",
    "    # datasets without oxygen (e.g. the velocity observations) are skipped\n",
    "    if 'o2' in v.data_vars:\n",
    "        # the threshold is scaled by 1e-6 before it is compared to `o2`\n",
    "        omz_shape = xr_extract_contour(v.sel(**tilt_roi).o2.load(),\n",
    "                                       tilt_threshold * 1e-6,\n",
    "                                       'xt_ocean',\n",
    "                                       'st_ocean',\n",
    "                                       showplot=False).to_dataset(name='depth')\n",
    "        filename = '%s_omz_shape.nc' %k\n",
    "        save_GRL_data(omz_shape, filename)\n",
    "# Only remove names bound in this cell (`ds` is not created here).\n",
    "del omz_shape, k, filename"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Momentum budget and mean map slices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving to ../data/interim/GRL_Paper_clean/CM21deg_hor_momentum_budget.nc\n",
      "[########################################] | 100% Completed |  9.7s\n",
      "Saving to ../data/interim/GRL_Paper_clean/CM26_hor_momentum_budget.nc\n",
      "[########################################] | 100% Completed |  3min 17.1s\n"
     ]
    }
   ],
   "source": [
    "# region of interest for the momentum budget\n",
    "mom_roi = dict(yu_ocean=slice(-10, 10),\n",
    "               yt_ocean=slice(-10, 10),\n",
    "               st_ocean=slice(250, 350),\n",
    "               sw_ocean=slice(250, 350))\n",
    "\n",
    "for k in models:\n",
    "    # calculate momentum budget and cut to roi\n",
    "    mom_budget = calculate_momentum_budget(data[k]).sel(**mom_roi)\n",
    "    grid = Grid(mom_budget)\n",
    "    # weighted average over depth (weights `dzt`/`dzu`), followed by the time average\n",
    "    mom_budget = xgcm_weighted_mean(grid, mom_budget, 'Z', ['dzt', 'dzu']).mean('time')\n",
    "    filename = '%s_hor_momentum_budget.nc' %k\n",
    "    with ProgressBar():\n",
    "        save_GRL_data(mom_budget, filename)\n",
    "    del mom_budget, grid, filename"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EOF Stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No `eddy_combined` generated for o2. dummy added.\n"
     ]
    }
   ],
   "source": [
    "# Region used for all EOF calculations below\n",
    "eof_roi = {'st_ocean': slice(0,500), 'xt_ocean': slice(-235,-80)}\n",
    "\n",
    "# Re-open the saved equatorial means of the models and recombine the budget terms\n",
    "eq_mean_split = dict()\n",
    "for k in models:\n",
    "    eq_mean_split[k] = mom_recombine_budgets(xr.open_dataset(join(outdir,'%s_eq_mean.nc' %k)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_u_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_advection_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_vertical_diff_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_advection_eddy_recon_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_advection_recon_w_diff_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_eddy_combined_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_u_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_advection_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_vertical_diff_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_advection_eddy_recon_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_advection_recon_w_diff_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/single_eof_z_o2_eddy_combined_CM26.nc\n"
     ]
    }
   ],
   "source": [
    "#!!!\n",
    "# Did I mention that and how the seasons are removed from the signal?\n",
    "from global_oxygen.GRL_paper import remove_seas_simple, eof_wrapper\n",
    "\n",
    "eof_vars = [\n",
    "    'o2',\n",
    "    'u',\n",
    "    'o2_advection',\n",
    "    'o2_vertical_diff',\n",
    "    'o2_advection_eddy_recon',\n",
    "    'o2_advection_recon_w_diff',\n",
    "    'o2_eddy_combined',\n",
    "]\n",
    "\n",
    "# single variable EOFs, stored as eof_z[model][variable]\n",
    "eof_z = dict()\n",
    "for k, ds in eq_mean_split.items():\n",
    "    # Remove the seasonal cycle (the time mean is added back), load into memory,\n",
    "    # focus on our roi and order the dimensions (this needs to be done in the eof wrapper)\n",
    "    ds = (\n",
    "        (remove_seas_simple(ds)+ds.mean('time'))\n",
    "        .load()\n",
    "        .sel(**eof_roi)\n",
    "        .transpose('time','st_ocean', 'xt_ocean',)\n",
    "    )\n",
    "    eof_z[k] = dict()\n",
    "    for var in eof_vars:\n",
    "        filename = 'single_eof_z_%s_%s.nc' %(var,k)\n",
    "        single_eof = eof_wrapper(ds, fields=[var], weight_fields=['dzt'], normalize=False)\n",
    "        eof_z[k][var] = single_eof\n",
    "        save_GRL_data(single_eof, filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_u_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_u_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_advection_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_advection_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_vertical_diff_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_vertical_diff_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_advection_eddy_recon_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_advection_eddy_recon_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_advection_recon_w_diff_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_advection_recon_w_diff_CM26.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_eddy_combined_CM21deg.nc\n",
      "Saving to ../data/interim/GRL_Paper_clean/multi_eof_z_o2_o2_eddy_combined_CM26.nc\n"
     ]
    }
   ],
   "source": [
    "# multivariate EOFs (o2 combined with each of the other variables)\n",
    "# The prepared dataset only depends on the model, so it is built once per model\n",
    "# instead of once for every (variable, model) pair. The same files are written,\n",
    "# only the order in which they are saved changes.\n",
    "for k, ds in eq_mean_split.items():\n",
    "    # Remove seasonal cycle (the time mean is added back) and load into memory\n",
    "    ds = (remove_seas_simple(ds)+ds.mean('time')).load()\n",
    "    ds = ds.sel(**eof_roi)\n",
    "    ds = ds.transpose('time','st_ocean', 'xt_ocean',) # This needs to be done in the eof wrapper\n",
    "    for var2 in [e for e in eof_vars if e != 'o2']:\n",
    "        filename = 'multi_eof_z_o2_%s_%s.nc' %(var2,k)\n",
    "        mv_eof_z = eof_wrapper(ds, fields=['o2', var2], weight_fields=['dzt', 'dzt'], n=3)\n",
    "        save_GRL_data(mv_eof_z, filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
